{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "# Batch Segmentation\n",
    "\n",
    "[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/geoai/blob/main/docs/examples/batch_segmentation.ipynb)\n",
    "\n",
    "This notebook demonstrates how to train semantic segmentation models and run batch segmentation.\n",
    "\n",
    "## Install packages\n",
    "\n",
    "Uncomment and run the following cell to install `geoai-py` if it is not already installed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install geoai-py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Import libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import geoai"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Download sample data\n",
    "\n",
    "We'll use the [waterbody dataset](https://www.kaggle.com/datasets/franciscoescobar/satellite-images-of-water-bodies) from Kaggle. Downloading it from Kaggle requires an account, so I have saved a copy on Hugging Face. Let's download that copy:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/waterbody-dataset.zip\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "out_folder = geoai.download_file(url)"
   ]
  },
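  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before training, it can help to confirm that every image has a matching mask. The cell below is a minimal sketch, assuming the archive extracts to `images` and `masks` subfolders whose files share the same base names (as the paths used later in this notebook suggest). `find_unpaired` is a hypothetical helper written for this check and is not part of `geoai`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def find_unpaired(images_dir, masks_dir):\n",
    "    \"\"\"Return base names found only in images_dir and only in masks_dir.\"\"\"\n",
    "    image_stems = {p.stem for p in Path(images_dir).iterdir() if p.is_file()}\n",
    "    mask_stems = {p.stem for p in Path(masks_dir).iterdir() if p.is_file()}\n",
    "    return sorted(image_stems - mask_stems), sorted(mask_stems - image_stems)\n",
    "\n",
    "\n",
    "# out_folder comes from the download cell above\n",
    "missing_masks, missing_images = find_unpaired(\n",
    "    f\"{out_folder}/images\", f\"{out_folder}/masks\"\n",
    ")\n",
    "print(f\"Images without a mask: {len(missing_masks)}\")\n",
    "print(f\"Masks without an image: {len(missing_images)}\")"
   ]
  },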
  {
   "cell_type": "markdown",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "## Train semantic segmentation model\n",
    "\n",
    "Now we'll train a semantic segmentation model using the new `train_segmentation_model` function. This function supports various architectures from `segmentation-models-pytorch`:\n",
    "\n",
    "- **Architectures**: `unet`, `unetplusplus`, `deeplabv3`, `deeplabv3plus`, `fpn`, `pspnet`, `linknet`, `manet`\n",
    "- **Encoders**: `resnet34`, `resnet50`, `efficientnet-b0`, `mobilenet_v2`, etc.\n",
    "\n",
    "For more details, please refer to the [segmentation-models-pytorch documentation](https://smp.readthedocs.io/en/latest/models.html).\n",
    "\n",
    "Let's train the model using U-Net with a ResNet34 encoder:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the segmentation model (image size is detected automatically)\n",
    "geoai.train_segmentation_model(\n",
    "    images_dir=f\"{out_folder}/images\",\n",
    "    labels_dir=f\"{out_folder}/masks\",\n",
    "    output_dir=f\"{out_folder}/unet_models\",\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    encoder_weights=\"imagenet\",\n",
    "    num_channels=3,\n",
    "    num_classes=2,  # background and water\n",
    "    batch_size=8,\n",
    "    num_epochs=30,\n",
    "    learning_rate=0.001,\n",
    "    val_split=0.2,\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluate the model\n",
    "\n",
    "Let's examine the training curves and model performance:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.plot_performance_metrics(\n",
    "    history_path=f\"{out_folder}/unet_models/training_history.pth\",\n",
    "    figsize=(15, 5),\n",
    "    verbose=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/381ce436-3520-4706-9def-b0a7ae8244ac)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference on a single image\n",
    "\n",
    "You can run inference on a new image using the `semantic_segmentation` function. I don't have a new image to test on, so I'll use one of the training images. In practice, you should use images that were not part of the training set, because results on training images overstate how well the model generalizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = 3\n",
    "test_image_path = f\"{out_folder}/images/water_body_{index}.jpg\"\n",
    "ground_truth_path = f\"{out_folder}/masks/water_body_{index}.jpg\"\n",
    "prediction_path = f\"{out_folder}/prediction/water_body_{index}.png\"  # save as png to preserve exact values and avoid compression artifacts\n",
    "model_path = f\"{out_folder}/unet_models/best_model.pth\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run semantic segmentation inference\n",
    "geoai.semantic_segmentation(\n",
    "    input_path=test_image_path,\n",
    "    output_path=prediction_path,\n",
    "    model_path=model_path,\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    num_channels=3,\n",
    "    num_classes=2,\n",
    "    window_size=512,\n",
    "    overlap=256,\n",
    "    batch_size=4,\n",
    ")"
   ]
  },
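  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `window_size` and `overlap` arguments above describe a sliding window over the image. The cell below is an illustrative sketch of the tiling arithmetic, assuming the window advances by `window_size - overlap` pixels and that a final window is placed flush with the image edge. `window_starts` is a hypothetical helper for illustration only; it is not how `geoai` is guaranteed to tile images internally."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def window_starts(length, window_size=512, overlap=256):\n",
    "    \"\"\"Return start offsets of sliding windows along one image axis.\"\"\"\n",
    "    stride = window_size - overlap\n",
    "    if stride <= 0:\n",
    "        raise ValueError(\"overlap must be smaller than window_size\")\n",
    "    if length <= window_size:\n",
    "        return [0]  # a single window covers the whole axis\n",
    "    starts = list(range(0, length - window_size + 1, stride))\n",
    "    # Add a last window flush with the edge so the trailing pixels are covered\n",
    "    if starts[-1] != length - window_size:\n",
    "        starts.append(length - window_size)\n",
    "    return starts\n",
    "\n",
    "\n",
    "# Offsets for a 1200-pixel axis with the settings used above\n",
    "print(window_starts(1200, window_size=512, overlap=256))"
   ]
  },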
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = geoai.plot_prediction_comparison(\n",
    "    original_image=test_image_path,\n",
    "    prediction_image=prediction_path,\n",
    "    ground_truth_image=ground_truth_path,\n",
    "    titles=[\"Original\", \"Prediction\", \"Ground Truth\"],\n",
    "    figsize=(15, 5),\n",
    "    save_path=f\"{out_folder}/prediction/water_body_{index}_comparison.png\",\n",
    "    show_plot=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![image](https://github.com/user-attachments/assets/00308228-0819-4161-9a35-6a98f4cefa93)"
   ]
  },
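  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To complement the visual comparison, the prediction can be scored against the ground truth with intersection over union (IoU). The cell below is a minimal sketch under two assumptions that are not confirmed by the dataset or library documentation: any non-zero pixel in the prediction means water, and water is white in the JPEG mask (thresholded at mid-gray to absorb compression noise). `binary_iou` is a hypothetical helper, not a `geoai` function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from PIL import Image\n",
    "\n",
    "\n",
    "def binary_iou(pred_mask, true_mask):\n",
    "    \"\"\"Intersection over union of two boolean arrays with the same shape.\"\"\"\n",
    "    if pred_mask.shape != true_mask.shape:\n",
    "        raise ValueError(\"prediction and ground truth must have the same shape\")\n",
    "    intersection = np.logical_and(pred_mask, true_mask).sum()\n",
    "    union = np.logical_or(pred_mask, true_mask).sum()\n",
    "    # Two empty masks agree perfectly, so report 1.0 instead of dividing by zero\n",
    "    return float(intersection / union) if union else 1.0\n",
    "\n",
    "\n",
    "# Assumption: non-zero prediction values mean water\n",
    "pred = np.array(Image.open(prediction_path).convert(\"L\")) > 0\n",
    "# Assumption: water is white in the mask; 127 is a mid-gray threshold\n",
    "truth = np.array(Image.open(ground_truth_path).convert(\"L\")) > 127\n",
    "print(f\"IoU: {binary_iou(pred, truth):.3f}\")"
   ]
  },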
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run inference on multiple images\n",
    "\n",
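    "First, let's download a sample of test images and masks. Then we'll run `semantic_segmentation_batch` on every image in the input folder and write the predictions to the output folder."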
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "url = \"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/waterbody-dataset-sample.zip\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_dir = geoai.download_file(url)\n",
    "images_dir = f\"{data_dir}/images\"\n",
    "masks_dir = f\"{data_dir}/masks\"\n",
    "predictions_dir = f\"{data_dir}/predictions\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "geoai.semantic_segmentation_batch(\n",
    "    input_dir=images_dir,\n",
    "    output_dir=predictions_dir,\n",
    "    model_path=model_path,\n",
    "    architecture=\"unet\",\n",
    "    encoder_name=\"resnet34\",\n",
    "    num_channels=3,\n",
    "    num_classes=2,\n",
    "    window_size=512,\n",
    "    overlap=256,\n",
    "    batch_size=4,\n",
    "    quiet=True,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "geo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
